package notifications

import (
	"context"
	"database/sql"
	"encoding/json"
	"github.com/olivere/elastic/v7"
)

// Repository gives access to the alarm feed shown to a user and to the
// user's "last seen" state within that feed.
type Repository interface {
	// GetAlarms returns up to size alarms for userId. afterAlarm, when it
	// identifies a previously returned alarm, selects the page that follows it.
	GetAlarms(ctx context.Context, userId string, size int, afterAlarm Alarm) ([]Alarm, error)

	// GetCount returns how many alarms userId has not seen yet.
	GetCount(ctx context.Context, userId string) (int, error)
}

// repository is the Repository implementation used by NewRepository. Alarms
// are read from an Elasticsearch index; each user's last-seen alarm id is
// kept in the SQL "notification" table.
type repository struct {
	esClient          *elastic.Client // client used to search alarmMessageIndex
	sqlDB             *sql.DB         // database holding the notification table (userId -> messageId)
	alarmMessageIndex string          // name of the Elasticsearch index containing alarm messages
}

// NewRepository returns a Repository that searches the alarmMessageIndex
// index through esClient and stores per-user last-seen alarm ids in sqlDB.
// The arguments are stored as given; none of them is validated here.
func NewRepository(esClient *elastic.Client, sqlDB *sql.DB, alarmMessageIndex string) Repository {
	repo := new(repository)
	repo.esClient = esClient
	repo.sqlDB = sqlDB
	repo.alarmMessageIndex = alarmMessageIndex
	return repo
}

// GetCount returns the number of alarms userId has not seen yet.
//
// It reads the id of the user's last-seen alarm from the notification table,
// then lists the newest alarms in the index. At most 1000 are listed, sorted
// by timestamp descending and then id ascending, which is the ordering
// GetAlarms uses. The position of the last-seen alarm in that list is the
// number of alarms newer than it. When the user has no marker, or the marked
// alarm is not among the listed ones, every listed alarm counts as unseen.
// The result is therefore capped at 1000.
//
// Errors from the database (other than sql.ErrNoRows) and from Elasticsearch
// are returned unwrapped, so callers can still inspect their concrete types.
func (r *repository) GetCount(ctx context.Context, userId string) (int, error) {
	// A missing row (sql.ErrNoRows) simply means "no marker yet".
	var messageId sql.NullString
	err := r.sqlDB.QueryRowContext(ctx, `SELECT messageId FROM notification where userId = ?;`, userId).
		Scan(&messageId)
	if err != nil && err != sql.ErrNoRows {
		return 0, err
	}

	// Only hit ids are compared below, so don't transfer each document's _source.
	searchService := r.esClient.Search(r.alarmMessageIndex).Size(1000).
		Sort("timestamp", false).
		Sort("id", true).
		FetchSource(false)
	searchResult, err := searchService.Do(ctx)
	if err != nil {
		return 0, err
	}
	if searchResult.Hits == nil {
		// Defensive: treat a response without a hits section as an empty index.
		return 0, nil
	}
	hits := searchResult.Hits.Hits

	if messageId.String != "" {
		for i, hit := range hits {
			if hit.Id == messageId.String {
				// i alarms precede the marker in newest-first order.
				return i, nil
			}
		}
	}
	return len(hits), nil
}

// GetAlarms returns up to size alarms, newest first (timestamp descending,
// then id ascending as a tie-breaker).
//
// If afterAlarm has both a non-empty Id and a positive Timestamp, it is used
// as a search_after cursor and the page following that alarm is returned.
// Otherwise the first page is returned. In that case, when the page is not
// empty, the id of its newest alarm is written to the notification table as
// userId's last-seen marker, which GetCount measures unread alarms against.
//
// Each hit's _source is decoded into an Alarm. Any search, decode, or
// database error is returned unwrapped together with a nil slice. When there
// are no hits the returned slice is nil.
func (r *repository) GetAlarms(ctx context.Context, userId string, size int, afterAlarm Alarm) ([]Alarm, error) {
	searchService := r.esClient.Search(r.alarmMessageIndex).Size(size).
		Sort("timestamp", false).
		Sort("id", true)

	// search_after needs one value per sort key (timestamp, id); an incomplete
	// cursor falls back to the first page. The same flag decides below whether
	// the last-seen marker is updated, so the two decisions always agree.
	paginated := len(afterAlarm.Id) > 0 && afterAlarm.Timestamp > 0
	if paginated {
		searchService = searchService.SearchAfter(afterAlarm.Timestamp, afterAlarm.Id)
	}

	searchResult, err := searchService.Do(ctx)
	if err != nil {
		return nil, err
	}
	if searchResult.Hits == nil {
		// Defensive: no hits section means nothing to return or mark as seen.
		return nil, nil
	}

	var alarms []Alarm
	for _, hit := range searchResult.Hits.Hits {
		var alarm Alarm
		if err := json.Unmarshal(hit.Source, &alarm); err != nil {
			return nil, err
		}
		alarms = append(alarms, alarm)
	}

	if !paginated && len(alarms) > 0 {
		// Record the newest alarm as seen. NOTE(review): REPLACE INTO only acts
		// as an upsert if userId is a primary/unique key of notification —
		// confirm against the schema.
		if _, err := r.sqlDB.ExecContext(ctx,
			`REPLACE INTO notification (userId, messageId) VALUES (?, ?);`,
			userId, alarms[0].Id,
		); err != nil {
			return nil, err
		}
	}
	return alarms, nil
}
